import copy
import itertools
import logging
import pathlib
from typing import List, Optional

import click
import yaml

# Absolute path of the directory that contains this file.
BASE_PATH = pathlib.Path(__file__).parent.resolve()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class dotdict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading an attribute that is not a key yields ``None`` rather than raising
    ``AttributeError``; as a consequence ``hasattr`` is true for every name.
    Deleting an attribute that is not a key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only consulted when regular attribute lookup fails; missing keys give None.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores an item instead of an instance attribute.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)


def resolve_cluster_config(cluster: str) -> str:
    """Translate a cluster identifier into its cluster configuration name.

    Raises:
        ValueError: If ``cluster`` is not one of the known identifiers.
    """
    config_by_cluster = {
        "dgxh100_eos": "eos",
        "dgxa100_dracooci": "draco-oci-iad",
        "dgxa100_dracooci-ord": "draco-oci-ord",
        "dgxh100_coreweave": "coreweave",
        "ghci": "ghci",
    }
    resolved = config_by_cluster.get(cluster)
    if resolved is None:
        raise ValueError(f"Unknown cluster {cluster} provided.")
    return resolved


def resolve_artifact_config(cluster: str) -> str:
    """Translate a cluster identifier into its artifact configuration name.

    Raises:
        ValueError: If ``cluster`` is not one of the known identifiers.
    """
    artifact_by_cluster = {
        "dgxh100_eos": "eos_lustre",
        "dgxa100_dracooci": "draco-oci_lustre",
        "dgxa100_dracooci-ord": "draco-oci-ord_lustre",
        "dgxh100_coreweave": "coreweave_lustre",
    }
    resolved = artifact_by_cluster.get(cluster)
    if resolved is None:
        raise ValueError(f"Unknown cluster {cluster} provided.")
    return resolved


def flatten_products(workload_manifest: dotdict) -> dotdict:
    """Expand the manifest's nested products into a flat list of dicts.

    Each entry of ``workload_manifest.products`` that has a ``"products"`` key
    contributes, for every mapping in that list, the cartesian product of the
    mapping's values. Every combination becomes one dict that maps the
    mapping's keys to the chosen values and adds ``"test_case"``, taken from the
    first element of the entry's ``"test_case"``. Entries without a
    ``"products"`` key are dropped; missing or empty products give ``[]``.

    The manifest is updated in place and returned.
    """
    expanded = []
    for entry in workload_manifest.products or []:
        if "products" not in entry:
            continue
        for axes in entry["products"]:
            for combination in itertools.product(*axes.values()):
                chosen = dict(zip(axes.keys(), combination))
                # Keyword merging (rather than a dict literal) keeps the TypeError
                # for a mapping that already carries a "test_case" key.
                expanded.append(dict(**chosen, test_case=entry["test_case"][0]))

    workload_manifest.products = expanded
    return workload_manifest


def flatten_workload(workload_manifest: dotdict) -> List[dotdict]:
    """Turn a manifest with products into one product-free workload per product.

    For every entry of ``products`` the rest of the manifest is deep-copied and
    the entry is merged into the copy's ``spec``: spec keys that the entry also
    defines are removed first, then the entry's items are appended. The input
    manifest itself is not modified.

    Raises:
        KeyError: If the manifest has no ``"products"`` key.
    """
    template = dict(workload_manifest)
    combinations = template.pop("products")

    def specialise(combination):
        variant = copy.deepcopy(template)
        untouched = {
            key: value
            for key, value in variant["spec"].items()
            if key not in combination.keys()
        }
        variant["spec"] = dict(**untouched, **combination)
        return dotdict(**variant)

    return [specialise(combination) for combination in combinations]


def set_build_dependency(workload_manifests: List[dotdict]) -> List[dotdict]:
    """Fill the placeholders in each workload's ``spec["build"]`` from its own spec.

    ``spec["build"]`` is treated as a ``str.format`` template and rendered with
    the spec's items as keyword arguments. The workloads are updated in place
    and the same list is returned.
    """
    for manifest in workload_manifests:
        spec = manifest.spec
        template = spec["build"]
        spec["build"] = template.format(**dict(spec))
    return workload_manifests


def load_config(config_path: str) -> dotdict:
    """Load a YAML workload manifest from disk.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The document's top-level mapping wrapped in a ``dotdict``. Nested
        mappings remain plain ``dict`` objects.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        TypeError: If the document is empty or its top level is not a mapping.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the locale's
    # default encoding.
    with open(config_path, encoding="utf-8") as stream:
        document = yaml.safe_load(stream)

    # safe_load returns None for an empty file and a list/scalar for other
    # documents; fail with a message that names the offending file.
    if not isinstance(document, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {config_path}, "
            f"got {type(document).__name__}."
        )

    return dotdict(**document)


def load_and_flatten(config_path: str) -> List[dotdict]:
    """Load a manifest file and return its fully expanded workloads.

    Runs, in order: ``load_config``, ``flatten_products``, ``flatten_workload``
    and ``set_build_dependency``.
    """
    manifest = load_config(config_path=config_path)
    manifest = flatten_products(manifest)
    expanded = flatten_workload(manifest)
    return set_build_dependency(expanded)


def filter_by_test_case(workload_manifests: List[dotdict], test_case: str) -> Optional[dotdict]:
    """Return the single workload whose ``spec["test_case"]`` equals ``test_case``.

    Returns ``None`` (after logging a message) when no workload matches or when
    more than one workload matches.
    """
    matches = [
        candidate
        for candidate in workload_manifests
        if candidate["spec"]["test_case"] == test_case
    ]

    if not matches:
        logger.info("No test_case found!")
        return None

    if len(matches) > 1:
        logger.info("Duplicate test_case found!")
        return None

    return matches[0]


def filter_by_scope(workload_manifests: List[dotdict], scope: str) -> List[dotdict]:
    """Return the workloads whose ``spec["scope"]`` equals ``scope``.

    Logs a message and returns an empty list when nothing matches.
    """
    selected = [
        candidate for candidate in workload_manifests if candidate.spec["scope"] == scope
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def filter_by_environment(workload_manifests: List[dotdict], environment: str) -> List[dotdict]:
    """Return the workloads whose ``spec["environment"]`` equals ``environment``.

    Workloads whose spec has no ``"environment"`` key are skipped.

    Args:
        workload_manifests: Workloads to filter; each must have a ``"spec"`` mapping.
        environment: Environment value to match.

    Returns:
        The matching workloads in their original order. An empty list (after
        logging a message) when nothing matches.
    """
    # Test key membership directly: ``hasattr`` on a ``dotdict`` is always true
    # because unknown attributes resolve to None, so it cannot guard the lookup
    # and a spec without the key used to raise KeyError.
    selected = [
        candidate
        for candidate in workload_manifests
        if "environment" in candidate["spec"]
        and candidate["spec"]["environment"] == environment
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def filter_by_platform(workload_manifests: List[dotdict], platform: str) -> List[dotdict]:
    """Return the workloads whose ``spec["platforms"]`` equals ``platform``.

    Workloads whose spec has no ``"platforms"`` key are skipped.

    Args:
        workload_manifests: Workloads to filter; each must have a ``"spec"`` mapping.
        platform: Platform value to match.

    Returns:
        The matching workloads in their original order. An empty list (after
        logging a message) when nothing matches.
    """
    # Test key membership directly: ``hasattr`` on a ``dotdict`` is always true
    # because unknown attributes resolve to None, so it cannot guard the lookup
    # and a spec without the key used to raise KeyError.
    selected = [
        candidate
        for candidate in workload_manifests
        if "platforms" in candidate["spec"] and candidate["spec"]["platforms"] == platform
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def filter_by_model(workload_manifests: List[dotdict], model: str) -> List[dotdict]:
    """Return the workloads whose ``spec["model"]`` equals ``model``.

    Logs a message and returns an empty list when nothing matches.
    """
    selected = [
        candidate for candidate in workload_manifests if candidate.spec["model"] == model
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def filter_by_tag(workload_manifests: List[dotdict], tag: str) -> List[dotdict]:
    """Return the workloads whose ``spec["tag"]`` equals ``tag``.

    Workloads whose spec has no ``"tag"`` key are skipped.

    Args:
        workload_manifests: Workloads to filter; each must have a ``"spec"`` mapping.
        tag: Tag value to match.

    Returns:
        The matching workloads in their original order. An empty list (after
        logging a message) when nothing matches.
    """
    # Test key membership directly: ``hasattr`` on a ``dotdict`` is always true
    # because unknown attributes resolve to None, so it cannot guard the lookup
    # and a spec without the key used to raise KeyError.
    selected = [
        candidate
        for candidate in workload_manifests
        if "tag" in candidate["spec"] and candidate["spec"]["tag"] == tag
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def filter_by_test_cases(workload_manifests: List[dotdict], test_cases: str) -> List[dotdict]:
    """Return the workloads whose test case is named in a comma-separated list.

    Args:
        workload_manifests: Workloads to filter; each must have
            ``spec["test_case"]``.
        test_cases: Comma-separated test case names, e.g. ``"a,b"``. Names are
            compared verbatim; surrounding whitespace is not stripped.

    Returns:
        The matching workloads in their original order, each at most once even
        if its name is listed several times. An empty list (after logging a
        message) when nothing matches.
    """
    # Split once instead of once per workload; a membership test (rather than a
    # nested loop over the names) keeps a repeated name from duplicating a workload.
    requested = test_cases.split(",")
    selected = [
        candidate
        for candidate in workload_manifests
        if candidate["spec"]["test_case"] in requested
    ]

    if not selected:
        logger.info("No test_case found!")

    return selected


def load_workloads(
    container_tag: str,
    n_repeat: int = 1,
    time_limit: int = 1800,
    tag: Optional[str] = None,
    environment: Optional[str] = None,
    platform: Optional[str] = None,
    test_cases: str = "all",
    scope: Optional[str] = None,
    model: Optional[str] = None,
    test_case: Optional[str] = None,
    container_image: Optional[str] = None,
    record_checkpoints: Optional[str] = None,
) -> List[dotdict]:
    """Load all recipe workloads from disk, filter them and attach run settings.

    Every ``*.yaml`` file in ``../recipes`` and ``../local_recipes`` (relative to
    ``BASE_PATH``) is loaded and flattened. Files whose stem starts with
    ``_build`` are additionally kept as build recipes. Build recipes referenced
    by a selected workload's ``spec["build"]`` are appended to the result.

    Args:
        container_tag: Tag appended to the image of each referenced build recipe;
            also part of the checkpoint artifact key.
        n_repeat: Stored as ``spec["n_repeat"]`` on every selected workload.
        time_limit: Stored as ``spec["time_limit"]`` on every selected workload.
        tag: If set, keep only workloads with this ``spec["tag"]``.
        environment: If set, keep only workloads with this ``spec["environment"]``.
        platform: If set, keep only workloads with this ``spec["platforms"]``.
        test_cases: Comma-separated test case names to keep; ``"all"`` disables
            this filter.
        scope: If set, keep only workloads with this ``spec["scope"]``.
        model: If set, keep only workloads with this ``spec["model"]``.
        test_case: If set, keep only the single workload with this test case.
        container_image: If set, replaces the image name of every referenced
            build recipe; otherwise each build recipe keeps its own image name.
        record_checkpoints: When equal to the string ``"true"``, a checkpoint
            artifact output is attached to every selected workload.

    Returns:
        The selected workloads followed by the build recipes they reference, or
        an empty list when no workload is selected.
    """
    recipes_dir = BASE_PATH / ".." / "recipes"
    local_dir = BASE_PATH / ".." / "local_recipes"

    workloads: List[dotdict] = []
    build_workloads: List = []
    for file in list(recipes_dir.glob("*.yaml")) + list(local_dir.glob("*.yaml")):
        workloads += load_and_flatten(config_path=str(file))
        if file.stem.startswith("_build"):
            build_workloads.append(load_config(config_path=str(file)))

    if scope:
        workloads = filter_by_scope(workload_manifests=workloads, scope=scope)

    if workloads and environment:
        workloads = filter_by_environment(workload_manifests=workloads, environment=environment)

    if workloads and model:
        workloads = filter_by_model(workload_manifests=workloads, model=model)

    if workloads and tag:
        workloads = filter_by_tag(workload_manifests=workloads, tag=tag)

    if workloads and platform:
        workloads = filter_by_platform(workload_manifests=workloads, platform=platform)

    if workloads and test_cases != "all":
        workloads = filter_by_test_cases(workload_manifests=workloads, test_cases=test_cases)

    if workloads and test_case:
        selected = filter_by_test_case(workload_manifests=workloads, test_case=test_case)
        # filter_by_test_case returns None for zero or ambiguous matches; wrapping
        # that in a list would yield a truthy [None] and crash further down.
        workloads = [] if selected is None else [selected]

    if not workloads:
        return []

    # Iterate over a snapshot: build recipes appended below must not be processed
    # as regular workloads.
    for workload in list(workloads):
        for build_workload in build_workloads:
            if (
                workload.spec["build"] == build_workload.spec["name"]
            ) and build_workload not in workloads:
                # Resolve the image per build recipe; reusing one variable across
                # iterations would leak the first recipe's image into later ones.
                image = container_image or build_workload.spec["source"]["image"]
                build_workload.spec["source"]["image"] = f"{image}:{container_tag}"
                workloads.append(build_workload)

        workload.spec["n_repeat"] = n_repeat
        workload.spec["time_limit"] = time_limit

        # A missing or null "artifacts" entry becomes an empty mapping; otherwise
        # the "{platforms}" placeholder in each value is substituted.
        artifacts = workload.spec.get("artifacts")
        workload.spec["artifacts"] = (
            {}
            if artifacts is None
            else {
                key: value.replace(r"{platforms}", workload.spec["platforms"])
                for key, value in artifacts.items()
            }
        )

        if record_checkpoints == "true":
            workload.outputs = [
                {
                    "type": "artifact",
                    "key": f"unverified/model/mcore-ci/{container_tag}/{{model}}/{{name}}",
                    "subdir": "checkpoints",
                    "name": r"{model}/{name}",
                    "description": r"Checkpoint of {model}/{name}",
                    "pic": {"name": "Mcore CI", "email": "okoenig@nvidia.com"},
                    "labels": {"origin": "ADLR/Megatron-LM"},
                }
            ]
    return workloads


@click.command()
@click.option("--model", required=False, type=str, default=None, help="Model to select")
@click.option("--test-case", required=False, type=str, default=None, help="Test case to select")
def main(model: Optional[str], test_case: Optional[str]):
    # CLI entry point: selects workloads for container tag "main", optionally
    # narrowed by --model / --test-case, and dumps them to workflows.yaml in the
    # current working directory.
    # (Kept as comments: a docstring would be picked up by click as help text.)
    selected = load_workloads(container_tag="main", model=model, test_case=test_case)
    with open("workflows.yaml", "w") as stream:
        # Convert each workload to a plain dict before serialising it.
        serialisable = [dict(item) for item in selected]
        yaml.dump(serialisable, stream)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
